package com.example.socket.filter;

import com.example.socket.core.Session;
import com.example.socket.filter.firewall.ManagementMatcher;
import com.example.socket.utils.IpUtils;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandler.Sharable;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.util.Attribute;
import org.springframework.beans.BeansException;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextAware;

import javax.annotation.PostConstruct;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.Map.Entry;
import java.util.Set;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantReadWriteLock;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/**
 * Management-console filter: marks a channel as a management-console connection by setting
 * {@link Session#MANAGEMENT_KEY} to {@code true} when its remote IP is allowed.
 *
 * <p>An IP is allowed when it matches one of the configured wildcard rules (see
 * {@link #setAllowIps}, {@link #setAllowIpConfig}, {@link #addAllowIp}), or, failing that,
 * when any {@link ManagementMatcher} bean collected from the application context accepts it.
 *
 * <p>The handler is {@code @Sharable}, so {@link #channelActive} may run on several threads
 * while rules are being changed; the rule table is therefore guarded by a read/write lock.
 */
@Sharable
public class ManagementFilter extends ChannelInboundHandlerAdapter implements ApplicationContextAware, FilterHandler {

    /** Guards {@link #allowRules}: read lock for matching, write lock for updates. */
    private final ReentrantReadWriteLock lock = new ReentrantReadWriteLock();
    /**
     * Allow rules keyed by their generated regex. Keying by the regex string (rather than by
     * {@link Pattern}, which only has identity equality) makes re-adding the same IP replace
     * the existing rule instead of appending a duplicate.
     */
    private final LinkedHashMap<String, AllowRule> allowRules = new LinkedHashMap<>();
    private ApplicationContext context;

    /** Extra matchers gathered from the application context in {@link #init()}. */
    private final Set<ManagementMatcher> matchers = new HashSet<>();

    /** A compiled allow rule and the management-console name configured for it. */
    private static final class AllowRule {
        private final Pattern pattern;
        // Retained as configured; not read by this class (the original also only stored it).
        private final String name;

        private AllowRule(Pattern pattern, String name) {
            this.pattern = pattern;
            this.name = name;
        }
    }

    /** Collects every {@link ManagementMatcher} bean from the application context. */
    @PostConstruct
    void init() {
        matchers.addAll(context.getBeansOfType(ManagementMatcher.class).values());
    }

    /**
     * Replaces all allow rules with the given IP expressions and management-console names.
     *
     * <p>All expressions are compiled before the current rules are touched, so an invalid
     * expression leaves the previous rules fully in place instead of a half-applied set.
     *
     * @param config key: IP wildcard expression ({@code '.'} is a literal dot, {@code '*'}
     *               matches any run of digits, including none); value: management-console name
     * @throws java.util.regex.PatternSyntaxException if an expression does not compile
     */
    public void setAllowIps(LinkedHashMap<String, String> config) {
        LinkedHashMap<String, AllowRule> compiled = new LinkedHashMap<>();
        for (Entry<String, String> entry : config.entrySet()) {
            String regex = toRegex(entry.getKey());
            compiled.put(regex, new AllowRule(Pattern.compile(regex), entry.getValue()));
        }

        Lock writeLock = lock.writeLock();
        writeLock.lock();
        try {
            allowRules.clear();
            allowRules.putAll(compiled);
        } finally {
            writeLock.unlock();
        }
    }

    /**
     * Replaces all allow rules from a textual configuration.
     *
     * @param config entries separated by {@code ","}, each entry being {@code [IP]=[NAME]}
     *               (e.g. {@code "127.0.0.1=local,192.168.*.*=lan"}). Whitespace around entries,
     *               IPs and names is ignored; blank entries are skipped; a {@code null} or
     *               blank string clears all rules.
     * @throws IllegalArgumentException if a non-blank entry contains no {@code "="}
     */
    public void setAllowIpConfig(String config) {
        LinkedHashMap<String, String> result = new LinkedHashMap<>();
        if (config != null) {
            for (String item : config.split(",")) {
                String entry = item.trim();
                if (entry.isEmpty()) {
                    // Tolerate an empty config string and stray/trailing commas.
                    continue;
                }
                String[] parts = entry.split("=", 2);
                if (parts.length < 2) {
                    throw new IllegalArgumentException(
                            "Invalid management allow-IP entry, expected [IP]=[NAME]: " + entry);
                }
                result.put(parts[0].trim(), parts[1].trim());
            }
        }
        setAllowIps(result);
    }

    /**
     * Adds (or replaces) a single allow rule.
     *
     * @param ip   IP wildcard expression ({@code '.'} is a literal dot, {@code '*'} matches any
     *             run of digits, including none)
     * @param name management-console name for this rule
     * @throws java.util.regex.PatternSyntaxException if the expression does not compile
     */
    public void addAllowIp(String ip, String name) {
        String regex = toRegex(ip);
        // Compile outside the lock so a bad expression never holds up readers.
        AllowRule rule = new AllowRule(Pattern.compile(regex), name);
        Lock writeLock = lock.writeLock();
        writeLock.lock();
        try {
            allowRules.put(regex, rule);
        } finally {
            writeLock.unlock();
        }
    }

    /**
     * Sets the management-console flag on the channel when its remote IP is allowed, then
     * propagates the event to the next handler.
     */
    @Override
    public void channelActive(ChannelHandlerContext ctx) throws Exception {
        Channel channel = ctx.channel();
        String ip = IpUtils.getIp(channel);
        // Configured rules are checked first; extra matchers only when no rule matched.
        if (matchesAllowRule(ip) || matchesManagementMatcher(ip)) {
            // Mark the channel as a management-console connection.
            Attribute<Boolean> attr = channel.attr(Session.MANAGEMENT_KEY);
            attr.set(true);
        }
        super.channelActive(ctx);
    }

    /** Returns whether {@code ip} fully matches any configured allow rule (under the read lock). */
    private boolean matchesAllowRule(String ip) {
        Lock readLock = lock.readLock();
        readLock.lock();
        try {
            for (AllowRule rule : allowRules.values()) {
                Matcher matcher = rule.pattern.matcher(ip);
                if (matcher.matches()) {
                    return true;
                }
            }
            return false;
        } finally {
            readLock.unlock();
        }
    }

    /** Returns whether any {@link ManagementMatcher} from the application context accepts {@code ip}. */
    private boolean matchesManagementMatcher(String ip) {
        for (ManagementMatcher matcher : matchers) {
            if (matcher.match(ip)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Converts an IP wildcard expression to a regex: {@code '.'} becomes a literal dot and
     * {@code '*'} becomes {@code [0-9]*}; all other characters are passed through unchanged.
     */
    private static String toRegex(String ip) {
        return ip.replace(".", "[.]").replace("*", "[0-9]*");
    }

    @Override
    public void setApplicationContext(ApplicationContext applicationContext) throws BeansException {
        this.context = applicationContext;
    }

    @Override
    public int getOrder() {
        return 40;
    }

    @Override
    public String getName() {
        return "managementFilter";
    }
}
